{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本分类实例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step1 导入相关包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.21-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step2 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>review</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7761</th>\n",
       "      <td>0</td>\n",
       "      <td>尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7762</th>\n",
       "      <td>0</td>\n",
       "      <td>盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7763</th>\n",
       "      <td>0</td>\n",
       "      <td>看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7764</th>\n",
       "      <td>0</td>\n",
       "      <td>我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7765</th>\n",
       "      <td>0</td>\n",
       "      <td>说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7766 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                             review\n",
       "0         1  距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...\n",
       "1         1                       商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
       "2         1         早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。\n",
       "3         1  宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...\n",
       "4         1               CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风\n",
       "...     ...                                                ...\n",
       "7761      0  尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...\n",
       "7762      0  盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...\n",
       "7763      0  看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...\n",
       "7764      0  我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...\n",
       "7765      0  说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...\n",
       "\n",
       "[7766 rows x 2 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "pretrain_model_dir = \"/data/models/huggingface/rbt3\"\n",
    "save_dir = '/data/logs/data_parallel_origin'\n",
    "data = pd.read_csv(\"./ChnSentiCorp_htl_all.csv\")\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>review</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7761</th>\n",
       "      <td>0</td>\n",
       "      <td>尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7762</th>\n",
       "      <td>0</td>\n",
       "      <td>盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7763</th>\n",
       "      <td>0</td>\n",
       "      <td>看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7764</th>\n",
       "      <td>0</td>\n",
       "      <td>我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7765</th>\n",
       "      <td>0</td>\n",
       "      <td>说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7765 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                             review\n",
       "0         1  距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...\n",
       "1         1                       商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
       "2         1         早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。\n",
       "3         1  宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...\n",
       "4         1               CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风\n",
       "...     ...                                                ...\n",
       "7761      0  尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...\n",
       "7762      0  盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...\n",
       "7763      0  看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...\n",
       "7764      0  我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...\n",
       "7765      0  说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...\n",
       "\n",
       "[7765 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = data.dropna()\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step3 创建Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.data = pd.read_csv(\"./ChnSentiCorp_htl_all.csv\")\n",
    "        self.data = self.data.dropna()\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data.iloc[index][\"review\"], self.data.iloc[index][\"label\"]\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较为简单.', 1)\n",
      "('商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!', 1)\n",
      "('早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。', 1)\n",
      "('宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小，但加上低价位因素，还是无超所值的；环境不错，就在小胡同内，安静整洁，暖气好足-_-||。。。呵还有一大优势就是从宾馆出发，步行不到十分钟就可以到梅兰芳故居等等，京味小胡同，北海距离好近呢。总之，不错。推荐给节约消费的自助游朋友~比较划算，附近特色小吃很多~', 1)\n",
      "('CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风', 1)\n"
     ]
    }
   ],
   "source": [
    "dataset = MyDataset()\n",
    "for i in range(5):\n",
    "    print(dataset[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step4 划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6989, 776)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.utils.data import random_split\n",
    "\n",
    "\n",
    "trainset, validset = random_split(dataset, lengths=[0.9, 0.1])\n",
    "len(trainset), len(validset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('五一节去的，住的是家庭房，别的倒还可以，不过不知道为什么房间里蚊子特别多，全家被蚊子咬的伤痕累累，亲手消灭了3只蚊子，然后仔细察看了一下墙壁和家具，有好多陈年的血迹，估计是前几批住客消灭蚊子后留下的。在杭州户外两天都没有被蚊子咬，在这家宾馆一晚上咬得倒是挺多的。补充点评2008年5月22日：总体来说去了两次杭州都是住这家宾馆的，感觉挺好的，第二次入住，还因为是老顾客送了一份水果，感觉很温馨。谢谢了！宾馆反馈2008年5月22日：尊敬的宾客,感谢您的点评.非常抱歉,近期酒店已安排专业的杀虫公司进行灭蝇防蚊工作,房务中心也准备了灭蚊片供客人使用,希望在您下次入住时能有更好的环境.期待着您的再次光临!', 1)\n",
      "('总体感觉很不错,对得起这个价格.就是隔间太差,晚上旁边房间看电视声音听得很清楚.其它什么都好.', 1)\n",
      "('因为未婚夫在美国的关系，他经常要美国、中国两边跑，每次都选择过境香港，所以给了我最好的理由去香港。在香港入住过不同的酒店，这家酒店给我的感觉最好。服务和环境都没有话可说，最感动的是，退房时走得匆忙，一条新买的裙子给忘在酒店房间了，直到回到家收拾行李才发现裙子不见了，当时已经是凌晨2点了，我抱着试试看的心态给酒店去了电话，结果接电话的男士很礼貌的告诉我裙子他们有收到。谢天谢地，经过沟通，以信用卡作邮费支付担保，我的裙子终于在失散一周之后回到了我的身边。下次未婚夫回国，我会强烈推荐他入住这家酒店，我想我以后去香港都会入住这家酒店', 1)\n",
      "('服务很周到，态度也很好，并且为我们升级到行政楼层。房间很干净，设施也很齐全，大床很大，睡得很爽，晚上还送了绿豆汤，很好喝，市内电话是免费的，很方便我们打114查询杭州美食。早餐品种很多，口味也很好，服务也很周到，会有服务员主动询问咖啡or红茶。以后如果有机会再去杭州，一定还住中山。Bytheway，门口的taxi最好能合理选配一下，竟然带我们绕路，多花了15元打的费不说，还害的我们没赶上车，浪费了车票。。宾馆反馈2008年8月12日：非常感谢您选择中山国际作为您的休栖之地，也非常感谢您对我们提出的意见和建议。酒店也将会向有关部门反映出租车的问题，希望能为客人提供更为有效的服务。忠心期盼您的再次光临。', 1)\n",
      "('通过携程定的每晚240元标间，性价比真的比较差。一、酒店通道和房间的异味特别浓烈，晚上久久无法入睡。二、房间设施过于简易，淋浴的水流特别小，在家洗澡的时间等于在酒店洗头发的时间。房间没有吹风机,洗好头发后需要向客服借用(客服竟然说全酒店只有2个吹风机,已经都借出去了没有了～)三、早晨起床后，发现洗漱间的水都是热水，孩子无法忍受热水索性就没洗漱。四、早餐一般也无所谓，服务生只负责撤盘子不负责换台布，台布很脏吃的倒胃口。第二天立刻换了酒店。', 0)\n",
      "('3楼是ktv，很吵；前台服务一般，退房时还莫名其妙有了99的国外长途费，房间内没有宽带，有个adsl猫，他', 0)\n",
      "('目前景德镇最好的酒店,这次住的是358元的商务标间,房间和卫生间都比较宽敞,干净.但是国外的卫星电视节目只有NHK和CNN.两部小电梯在客人集中的时段要等很久.早餐的确很差.', 1)\n",
      "('总体上感觉还可以，但也没有什么太突出的地方，因为才过一个月多点，实在也想不起来有什么让人印象深刻点的。', 1)\n",
      "('次在此住了9晚上...房大...有蚊子..隔音太差...向可否能蚊香...酒店台居然有(上是有的)服太惰....之後反...我的房到其他..住了2天居然被房跳蚤咬了整大腿14-15苞...看著我被跳蚤咬的苞...人也什...在可!!要我事後打酒店理理....住了整整9夜...由於上海出差18天...之後就改去家店~~後...由於我的卡(不是境卡)需要打行..台值班理居然要我自行回台....服品跟房生在是很糟糕...补充点评2007年10月25日：ぇ...┬丁矫ネび...叭キび,τぃ镑,搂华ぃ,...┬丁砰ぃ,び,び,俱丁┬丁3秸,Τ８镑....ρ侣ぃ....ぃ崩滤┍!!', 0)\n",
      "('非常好，交通很方便，就在市区，娱乐也很方便。就是携程的价格不是很优惠。', 1)\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    print(trainset[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step5 创建Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(pretrain_model_dir)\n",
    "\n",
    "def collate_func(batch):\n",
    "    texts, labels = [], []\n",
    "    for item in batch:\n",
    "        texts.append(item[0])\n",
    "        labels.append(item[1])\n",
    "    inputs = tokenizer(texts, max_length=128, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    inputs[\"labels\"] = torch.tensor(labels)\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "trainloader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_func)\n",
    "validloader = DataLoader(validset, batch_size=64, shuffle=False, collate_fn=collate_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[ 101, 6983, 2421,  ...,    0,    0,    0],\n",
       "        [ 101, 1762,  677,  ...,  705, 2168,  102],\n",
       "        [ 101, 6821,  702,  ..., 5307, 1469,  102],\n",
       "        ...,\n",
       "        [ 101, 1765, 4415,  ...,    0,    0,    0],\n",
       "        [ 101, 2595,  817,  ...,    0,    0,    0],\n",
       "        [ 101, 6421, 6983,  ..., 2414,  855,  102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        ...,\n",
       "        [1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "        0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1])}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(enumerate(validloader))[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step6 创建模型及优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data/models/huggingface/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from torch.optim import Adam\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(pretrain_model_dir)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = Adam(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step7 训练与验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def evaluate():\n",
    "    model.eval()\n",
    "    acc_num = 0\n",
    "    with torch.inference_mode():\n",
    "        for batch in validloader:\n",
    "            if torch.cuda.is_available():\n",
    "                batch = {k: v.cuda() for k, v in batch.items()}\n",
    "            output = model(**batch)\n",
    "            pred = torch.argmax(output.logits, dim=-1)\n",
    "            acc_num += (pred.long() == batch[\"labels\"].long()).float().sum()\n",
    "    return acc_num / len(validset)\n",
    "\n",
    "def train(epoch=3, log_step=100):\n",
    "    global_step = 0\n",
    "    for ep in range(epoch):\n",
    "        model.train()\n",
    "        start = time.time()\n",
    "        for batch in trainloader:\n",
    "            if torch.cuda.is_available():\n",
    "                batch = {k: v.cuda() for k, v in batch.items()}\n",
    "            optimizer.zero_grad()\n",
    "            output = model(**batch)\n",
    "            loss = output.loss.mean()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if global_step % log_step == 0:\n",
    "                print(f\"ep: {ep}, global_step: {global_step}, loss: {loss.item()}\")\n",
    "            global_step += 1\n",
    "        acc = evaluate()\n",
    "        print(f\"ep: {ep}, acc: {acc}, time: {time.time() - start}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step8 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ep: 0, global_step: 0, loss: 0.6320370435714722\n",
      "ep: 0, global_step: 100, loss: 0.2564571797847748\n",
      "ep: 0, global_step: 200, loss: 0.1607816219329834\n",
      "ep: 0, acc: 0.8917525410652161, time: 24.14542531967163\n",
      "ep: 1, global_step: 300, loss: 0.17897820472717285\n",
      "ep: 1, global_step: 400, loss: 0.2679259181022644\n",
      "ep: 1, acc: 0.907216489315033, time: 22.623146772384644\n",
      "ep: 2, global_step: 500, loss: 0.09468485414981842\n",
      "ep: 2, global_step: 600, loss: 0.0822243019938469\n",
      "ep: 2, acc: 0.8891752362251282, time: 22.671192407608032\n"
     ]
    }
   ],
   "source": [
    "train()  # 1m3.6s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step9 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入：我觉得这家酒店不错，饭很好吃！\n",
      "模型预测结果:好评！\n"
     ]
    }
   ],
   "source": [
    "sen = \"我觉得这家酒店不错，饭很好吃！\"\n",
    "id2_label = {0: \"差评！\", 1: \"好评！\"}\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    inputs = tokenizer(sen, return_tensors=\"pt\")\n",
    "    inputs = {k: v.cuda() for k, v in inputs.items()}\n",
    "    logits = model(**inputs).logits\n",
    "    pred = torch.argmax(logits, dim=-1)\n",
    "    print(f\"输入：{sen}\\n模型预测结果:{id2_label.get(pred.item())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model.config.id2label = id2_label\n",
    "pipe = pipeline(\"text-classification\", model=model, tokenizer=tokenizer, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': '好评！', 'score': 0.9895836114883423}]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
